#!/usr/bin/env python3

from http.server import SimpleHTTPRequestHandler, HTTPServer
import socket
import csv
import sys
import tempfile
import threading
import os
import gzip
import traceback
import urllib.request
import subprocess
import lzma


def is_ipv6(host):
    """Return True when *host* is not a valid IPv4 address string.

    This is a heuristic: every string that ``socket.inet_aton`` rejects is
    reported as IPv6, so the result is only meaningful for inputs that are
    already known to be IP addresses.

    Args:
        host: IP address in textual form.

    Returns:
        False if *host* parses as IPv4, True otherwise.
    """
    try:
        socket.inet_aton(host)
    except OSError:
        # inet_aton signals an invalid IPv4 string with OSError. Catching only
        # that keeps KeyboardInterrupt/SystemExit from being swallowed.
        return True
    return False


def get_local_port(host, ipv6):
    """Return a port number on *host* that the OS reported as free.

    A temporary socket is bound to port 0 so that the OS picks the port; the
    socket is closed again before returning, so the port is only probed here,
    not reserved.

    Args:
        host: local address to bind the probe socket to.
        ipv6: when true an AF_INET6 socket is used, otherwise AF_INET.

    Returns:
        The port number the probe socket was bound to.
    """
    address_family = socket.AF_INET6 if ipv6 else socket.AF_INET
    probe = socket.socket(address_family)
    with probe:
        probe.bind((host, 0))
        bound_address = probe.getsockname()
    # The port is the second element of the address for both families.
    return bound_address[1]


# Host and HTTP port of the ClickHouse server under test; both can be
# overridden through environment variables of the same name.
CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
CLICKHOUSE_PORT_HTTP = os.getenv("CLICKHOUSE_PORT_HTTP", "8123")

#####################################################################################
# This test starts an HTTP server and serves CSV data to a ClickHouse dictionary
# that uses an HTTP source.
# Its main goal is to check that the supported compression methods work.
# For the test to work, the IP address and port of the HTTP server (defined below)
# must be reachable from the ClickHouse server.
#####################################################################################

# IP address of this host: the first address printed by `hostname -i`.
# It has to be reachable from the ClickHouse server.
HTTP_SERVER_HOST = (
    subprocess.check_output(["hostname", "-i"]).decode("utf-8").split()[0]
)
IS_IPV6 = is_ipv6(HTTP_SERVER_HOST)
HTTP_SERVER_PORT = get_local_port(HTTP_SERVER_HOST, IS_IPV6)

# Address the HTTP server started from this script binds to, and the base URL
# under which ClickHouse reaches it.
HTTP_SERVER_ADDRESS = (HTTP_SERVER_HOST, HTTP_SERVER_PORT)
# An IPv6 address has to be wrapped in brackets inside a URL.
HTTP_SERVER_URL_STR = (
    f"http://[{HTTP_SERVER_HOST}]:{HTTP_SERVER_PORT}/"
    if IS_IPV6
    else f"http://{HTTP_SERVER_HOST}:{HTTP_SERVER_PORT}/"
)

# The served payload is built in memory, so the test does not need to read a
# CSV file from disk to know what content to expect.
CSV_DATA = (
    "Hello, 1\n"
    "World, 2\n"
    "This, 152\n"
    "is, 9283\n"
    "testing, 2313213\n"
    "data, 555\n"
)


# Mutable test state: test_select() rewrites these between requests and the
# HTTP handler reads them while answering.
# URL suffixes used for the requests that do not send Content-Encoding.
ENDINGS = [".gz", ".xz"]
# Suffix appended to the served file name in the dictionary URL.
ADDING_ENDING = ""
# Whether the handler sends a Content-Encoding header.
SEND_ENCODING = True
# Current compression method; "none" means the payload is sent as is, which
# also checks that plain data transfer still works.
COMPRESS_METHOD = "none"


def get_ch_answer(query):
    """Send *query* to ClickHouse over HTTP and return the decoded response.

    The target URL is taken from the ``CLICKHOUSE_URL`` environment variable
    when it is set; otherwise it is built from ``CLICKHOUSE_HOST`` and
    ``CLICKHOUSE_PORT_HTTP``.

    Args:
        query: SQL text, sent as the request body.

    Returns:
        The response body decoded to ``str``.

    Raises:
        Whatever ``urllib.request.urlopen`` raises for connection or HTTP
        errors.
    """
    host = CLICKHOUSE_HOST
    # An IPv6 literal must be bracketed inside a URL. The decision is made
    # from the ClickHouse host itself: IS_IPV6 describes the test HTTP server,
    # and bracketing a host name such as "localhost" would break the URL.
    if ":" in host and not host.startswith("["):
        host = f"[{host}]"

    url = os.environ.get("CLICKHOUSE_URL", f"http://{host}:{CLICKHOUSE_PORT_HTTP}")
    # Close the response explicitly instead of leaving it to the GC.
    with urllib.request.urlopen(url, data=query.encode()) as response:
        return response.read().decode()


def check_answers(query, answer):
    """Run *query* on ClickHouse and compare the response with *answer*.

    Leading and trailing whitespace is ignored on both sides. On a mismatch
    the query, the expected answer and the fetched answer are printed to
    stderr before raising.

    Raises:
        Exception: the fetched answer differs from the expected one.
    """
    fetched = get_ch_answer(query)
    if fetched.strip() == answer.strip():
        return

    report = (
        ("FAIL on query:", query),
        ("Expected answer:", answer),
        ("Fetched answer :", fetched),
    )
    for label, value in report:
        print(label, value, file=sys.stderr)
    raise Exception("Fail on query")


# Request handler that also implements HEAD, which is useful for debugging by hand
class HttpProcessor(SimpleHTTPRequestHandler):
    """Handler that answers every GET/HEAD request with the CSV test payload.

    The request path is never inspected. How the payload is encoded and
    whether a Content-Encoding header is sent is driven by the module-level
    COMPRESS_METHOD and SEND_ENCODING values at the time of the request.
    """

    def log_message(self, format, *args):
        """Override the base handler's logging with a no-op."""
        return

    def compress_data(self):
        """Store the payload for the current COMPRESS_METHOD in ``self.data``.

        Methods other than "gzip" and "lzma" produce a fixed marker payload.
        """
        compressors = {"gzip": gzip.compress, "lzma": lzma.compress}
        compressor = compressors.get(COMPRESS_METHOD)
        if compressor is None:
            self.data = "WRONG CONVERSATION".encode()
        else:
            self.data = compressor(CSV_DATA.encode())

    def _set_headers(self):
        """Send the 200 status line and all response headers."""
        self.send_response(200)
        if SEND_ENCODING:
            self.send_header("Content-Encoding", COMPRESS_METHOD)

        if COMPRESS_METHOD != "none":
            # The compressed payload is needed already here to report its size.
            self.compress_data()
            body_length = len(self.data)
        else:
            body_length = len(CSV_DATA.encode())
        self.send_header("Content-Length", body_length)

        self.send_header("Content-Type", "text/csv")
        self.end_headers()

    def do_HEAD(self):
        """Answer HEAD with the same headers as GET and no body."""
        self._set_headers()

    def do_GET(self):
        """Answer GET with the headers followed by the payload."""
        self._set_headers()

        # self.data was filled by _set_headers() for every method but "none".
        payload = CSV_DATA.encode() if COMPRESS_METHOD == "none" else self.data
        self.wfile.write(payload)


class HTTPServerV6(HTTPServer):
    """HTTPServer variant whose socket uses the IPv6 address family."""

    address_family = socket.AF_INET6


def start_server(requests_amount):
    """Create the test HTTP server and a thread that serves a fixed number of requests.

    The server is bound to HTTP_SERVER_ADDRESS immediately; an IPv6-capable
    server class is used when IS_IPV6 is set. The returned thread is not
    started here.

    Args:
        requests_amount: how many requests the thread handles before it
            finishes.

    Returns:
        A ``threading.Thread`` that the caller has to ``start()`` and
        ``join()``.
    """
    server_class = HTTPServerV6 if IS_IPV6 else HTTPServer
    httpd = server_class(HTTP_SERVER_ADDRESS, HttpProcessor)

    def real_func():
        try:
            for _ in range(requests_amount):
                httpd.handle_request()
        finally:
            # Release the listening socket once serving is over, even if a
            # request handler raised.
            httpd.server_close()

    t = threading.Thread(target=real_func)
    return t


#####################################################################
# Testing area.
#####################################################################


def test_select(
    dict_name="",
    schema="word String, counter UInt32",
    requests=None,
    answers=None,
    test_data="",
):
    """Recreate the dictionary for each compression method and check its contents.

    For every entry of *requests* the module-level COMPRESS_METHOD is set to
    that entry, the dictionary is recreated (when *dict_name* is non-empty)
    and ``SELECT * FROM <dict_name> ORDER BY word`` is compared with the
    matching entry of *answers*.

    Args:
        dict_name: name of the dictionary to (re)create and select from.
        schema: column definitions used in CREATE DICTIONARY.
        requests: compression methods, one per request; defaults to no
            requests.
        answers: expected query results, indexed like *requests*.
        test_data: unused; kept so existing callers keep working.

    Raises:
        Exception: propagated from check_answers() on a mismatch.
        IndexError: *answers* is shorter than *requests*, or there are more
            than three plus ``len(ENDINGS)`` requests.
    """
    global ADDING_ENDING
    global SEND_ENCODING
    global COMPRESS_METHOD

    # None stands in for the former mutable list defaults.
    requests = [] if requests is None else requests
    answers = [] if answers is None else answers

    for i, compress_method in enumerate(requests):
        if i > 2:
            # From the fourth request on, the Content-Encoding header is no
            # longer sent and the URL gets a file-name suffix instead,
            # presumably so the compression has to be detected from the URL.
            ADDING_ENDING = ENDINGS[i - 3]
            SEND_ENCODING = False

        if dict_name:
            get_ch_answer(f"drop dictionary if exists {dict_name}")
            # HTTP_SERVER_URL_STR already ends with "/", so the path starts
            # with a double slash; the test HTTP handler ignores the path.
            get_ch_answer(
                """CREATE DICTIONARY {} ({})
            PRIMARY KEY word
            SOURCE(HTTP(url '{}' format 'CSV'))
            LAYOUT(complex_key_hashed())
            LIFETIME(0)""".format(
                    dict_name, schema, HTTP_SERVER_URL_STR + "/test.csv" + ADDING_ENDING
                )
            )

        # Switch the method only after the dictionary has been recreated.
        COMPRESS_METHOD = compress_method
        print(i, COMPRESS_METHOD, ADDING_ENDING, SEND_ENCODING)
        check_answers(f"SELECT * FROM {dict_name} ORDER BY word", answers[i])


def main():
    """Serve one HTTP request per compression scenario and verify each result.

    Prints "PASSED" once every scenario has matched its expected answer.
    """
    # The first three methods are announced through the Content-Encoding
    # header, the last two through the URL ending.
    insert_requests = ["none", "gzip", "lzma", "gzip", "lzma"]

    # The expected output was obtained experimentally in uncompressed mode and
    # is the same for every request.
    expected = "Hello\t1\nThis\t152\nWorld\t2\ndata\t555\nis\t9283\ntesting\t2313213"
    answers = [expected for _ in insert_requests]

    server_thread = start_server(len(insert_requests))
    server_thread.start()
    test_select(
        dict_name="test_table_select", requests=insert_requests, answers=answers
    )
    server_thread.join()
    print("PASSED")


if __name__ == "__main__":
    try:
        main()
    except Exception as ex:
        # Report the failure on stderr: first the traceback, then the message.
        traceback.print_tb(ex.__traceback__, file=sys.stderr)
        print(ex, file=sys.stderr)
        sys.stderr.flush()

        # os._exit ends the process right away; presumably chosen so that the
        # still-running server thread cannot keep the failed test alive.
        os._exit(1)
